// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/format"
	"io"
	"os"
	"path"
	"strings"

	"gopkg.in/yaml.v3"

	"github.com/dolthub/go-mysql-server/enginetest"
	"github.com/dolthub/go-mysql-server/enginetest/queries"
	"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/planbuilder"
)

//go:generate go run ./main.go -srcRoot=../../../../ plan ../../testdata/spec.yaml

// Errors reported (through exit) when the command line is malformed.
var (
	errInvalidArgCount     = errors.New("invalid number of arguments")
	errUnrecognizedCommand = errors.New("unrecognized command")
)

// Command-line flags. pkg is the package name written into the package clause of
// every generated file; srcRoot is the directory that each spec entry's path is
// joined onto to form the output file path.
var (
	pkg     = flag.String("pkg", "queries", "package name used in generated files")
	srcRoot = flag.String("srcRoot", "", "path to package root")
)

// useGoFmt controls whether generated source is passed through go/format before it
// is written out (see toFile).
const useGoFmt = true

// main parses the command line, which is expected to be `plan <spec>` (any further
// positional arguments are ignored), and generates every plan file listed in the
// spec. Malformed invocations print usage and exit with an error.
func main() {
	flag.Usage = usage
	flag.Parse()

	positional := flag.Args()
	if len(positional) < 2 {
		flag.Usage()
		exit(errInvalidArgCount)
	}

	// "plan" is currently the only supported command.
	if command := positional[0]; command != "plan" {
		flag.Usage()
		exit(errUnrecognizedCommand)
	}

	if genErr := generatePlans(positional[1], *srcRoot); genErr != nil {
		exit(genErr)
	}
}

// PlanSpecs is the top-level document of a plan spec YAML file: the list of plan
// suites to generate, under the `plans` key.
type PlanSpecs struct {
	Plans []PlanSpec `yaml:"plans"`
}

// PlanSpec describes one generated file. Name selects the test suite to generate
// and is also used as the name of the generated variable; Path is the output file
// location, joined onto the -srcRoot flag.
type PlanSpec struct {
	Name string `yaml:"name"`
	Path string `yaml:"path"`
}

// ParseSpec reads the YAML plan specification at path and decodes it into
// PlanSpecs. Keys that do not correspond to a PlanSpecs/PlanSpec field are
// rejected (KnownFields), so typos in the spec surface as errors instead of being
// silently ignored. On error the returned PlanSpecs is the zero value and the
// error wraps the underlying cause.
func ParseSpec(path string) (PlanSpecs, error) {
	contents, err := os.ReadFile(path)
	if err != nil {
		// os.ReadFile's error already names the file.
		return PlanSpecs{}, fmt.Errorf("reading plan spec: %w", err)
	}
	dec := yaml.NewDecoder(bytes.NewReader(contents))
	dec.KnownFields(true)
	var res PlanSpecs
	if err := dec.Decode(&res); err != nil {
		if errors.Is(err, io.EOF) {
			// An empty document decodes to io.EOF, which on its own is a cryptic message.
			return PlanSpecs{}, fmt.Errorf("plan spec %q is empty: %w", path, err)
		}
		return PlanSpecs{}, fmt.Errorf("decoding plan spec %q: %w", path, err)
	}
	return res, nil
}

// writeHeader writes the preamble of a generated file to buf, in this order:
//   - the generated-code marker;
//   - the Apache 2.0 license header;
//   - the package clause for pkg, followed by a blank line.
func writeHeader(buf *bytes.Buffer, pkg string) {
	// The marker must match `^// Code generated .* DO NOT EDIT\.$` for gofmt,
	// linters and review tools to recognize the output as generated code.
	_, _ = fmt.Fprint(buf, "// Code generated by plangen. DO NOT EDIT.\n\n")
	_, _ = fmt.Fprint(buf, `// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.`)
	_, _ = fmt.Fprint(buf, "\n\n")
	_, _ = fmt.Fprintf(buf, "package %s\n\n", pkg)
}

// generatePlans reads the plan spec at specPath and generates one Go source file
// per spec entry, written to path.Join(srcRoot, entry.Path).
//
// The entry named "QueryPlanScriptTests" is produced by
// generatePlansForScriptSuite, and its file also imports the sql package. Every
// other entry is produced by generatePlansForSuite.
//
// Generation stops at the first error. That error is returned with the suite
// name as context, and no further files are written.
func generatePlans(specPath string, srcRoot string) error {
	specs, err := ParseSpec(specPath)
	if err != nil {
		return err
	}
	for _, spec := range specs.Plans {
		var buf bytes.Buffer
		writeHeader(&buf, *pkg)
		if spec.Name == "QueryPlanScriptTests" {
			// The script suite emits []sql.Row literals, so its file must import sql.
			_, _ = fmt.Fprint(&buf, "import (\n\t\"github.com/dolthub/go-mysql-server/sql\"\n)\n\n")
			err = generatePlansForScriptSuite(spec, &buf)
		} else {
			err = generatePlansForSuite(spec, &buf)
		}
		if err != nil {
			return fmt.Errorf("generating %s: %w", spec.Name, err)
		}
		out := path.Join(srcRoot, spec.Path)
		if err = toFile(buf, out); err != nil {
			return fmt.Errorf("writing %s to %q: %w", spec.Name, out, err)
		}
	}
	return nil
}

// writePlanString writes planString to w as a Go expression that reproduces it
// exactly. The expression is a "+"-concatenation of interpreted string literals:
//   - one literal per newline-terminated line of planString, each ending in "\n";
//   - a final literal holding whatever follows the last newline (normally "").
//
// The final literal is followed by a comma and a newline, so the output can be
// used directly as the value of a composite-literal field.
func writePlanString(w *bytes.Buffer, planString string) {
	// Escape backslashes as well as quotes. Otherwise a sequence like `\n` or `\\`
	// in a plan would be reinterpreted by the Go compiler. The replacer makes a
	// single pass, so an escaped quote is not escaped twice.
	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	lines := strings.Split(planString, "\n")
	last := len(lines) - 1
	for i, line := range lines {
		if i > 0 {
			_, _ = w.WriteString(" + \n")
		}
		if i < last {
			// Every non-final segment was followed by a newline in planString,
			// including empty interior lines.
			_, _ = fmt.Fprintf(w, `"%s\n"`, escaper.Replace(line))
		} else {
			// Trailing segment (empty when planString ends in "\n"), plus the field comma.
			_, _ = fmt.Fprintf(w, "\"%s\",\n", escaper.Replace(line))
		}
	}
}

// analyzeQuery turns query into an analyzed plan in two steps:
//  1. A planbuilder bound to the engine analyzer's catalog and the engine's event
//     scheduler parses the query.
//  2. The engine's analyzer runs over the parsed plan, and the result is returned.
//
// A parse or analysis failure is fatal: it is reported together with the
// offending query and the program exits.
func analyzeQuery(ctx *sql.Context, engine enginetest.QueryEngine, query string) sql.Node {
	fail := func(cause error, stage string) {
		exit(fmt.Errorf("%w\nfailed to %s query: %s", cause, stage, query))
	}

	builder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, engine.EngineEventScheduler(), nil)
	parsed, _, _, queryFlags, parseErr := builder.Parse(query, nil, false)
	if parseErr != nil {
		fail(parseErr, "parse")
	}

	analyzed, analyzeErr := engine.EngineAnalyzer().Analyze(ctx, parsed, nil, queryFlags)
	if analyzeErr != nil {
		fail(analyzeErr, "analyze")
	}
	return analyzed
}

// generatePlansForSuite writes a declaration `var <spec.Name> = []QueryPlanTest{...}`
// to w.
//
// It loads the suite's setup scripts (specSetup) into a new in-memory harness.
// It then re-emits every query from specQueries together with freshly generated
// expectations:
//   - ExpectedPlan: the debug description of the analyzed plan.
//   - For read-only plans only, ExpectedEstimates (plan description with
//     estimates) and ExpectedAnalysis (plan description with analysis and
//     estimates, taken after executing the node). Either is written as "skip"
//     when the source test sets that field to "skip".
//
// Skipped tests are emitted with `Skip: true` and no expectations.
//
// Engine-creation and execution failures are returned. Parse and analysis
// failures exit the program (see analyzeQuery).
func generatePlansForSuite(spec PlanSpec, w *bytes.Buffer) error {
	harness := enginetest.NewMemoryHarness("default", 1, 1, true, nil)
	harness.Setup(specSetup(spec.Name)...)
	engine, err := harness.NewEngine(nil)
	if err != nil {
		return err
	}

	suiteQueries := specQueries(spec.Name)

	_, _ = fmt.Fprintf(w, "var %s = []QueryPlanTest{\n", spec.Name)
	for _, tt := range suiteQueries {
		_, _ = w.WriteString("\t{\n")
		writeQueryField(w, tt.Query)

		if tt.Skip {
			// Interpreted string: the generated source needs a real newline here.
			_, _ = w.WriteString("Skip: true,\n")
		} else {
			ctx := enginetest.NewContext(harness)
			node := analyzeQuery(ctx, engine, tt.Query)
			_, _ = w.WriteString("ExpectedPlan: ")
			writePlanString(w, sql.Describe(enginetest.ExtractQueryNode(node), sql.DescribeOptions{
				Debug: true,
			}))

			if node.IsReadOnly() {
				if tt.ExpectedEstimates == "skip" {
					_, _ = w.WriteString("ExpectedEstimates: \"skip\",\n")
				} else {
					_, _ = w.WriteString("ExpectedEstimates: ")
					writePlanString(w, sql.Describe(enginetest.ExtractQueryNode(node), sql.DescribeOptions{
						Estimates: true,
					}))
				}

				if tt.ExpectedAnalysis == "skip" {
					_, _ = w.WriteString("ExpectedAnalysis: \"skip\",\n")
				} else {
					_, _ = w.WriteString("ExpectedAnalysis: ")
					// The node is executed before being described with Analyze: true,
					// presumably so the description reflects the actual run (TODO confirm).
					if err := enginetest.ExecuteNode(ctx, engine, node); err != nil {
						return fmt.Errorf("%w\nfailed to execute query: %s", err, tt.Query)
					}
					writePlanString(w, sql.Describe(enginetest.ExtractQueryNode(node), sql.DescribeOptions{
						Analyze:   true,
						Estimates: true,
					}))
				}
			}
		}

		_, _ = w.WriteString("\t},\n")
	}
	_, _ = w.WriteString("}")

	return nil
}

// writeQueryField writes the `Query:` field of a generated test to w, followed by a
// newline. It handles two cases:
//   - A query without backticks is written as a raw string literal.
//   - A query containing a backtick cannot be written that way. It is trimmed and
//     written as a "+"-concatenation of interpreted literals, one per line, each
//     ending in "\n", plus a terminating "" literal.
func writeQueryField(w *bytes.Buffer, query string) {
	if !strings.Contains(query, "`") {
		_, _ = fmt.Fprintf(w, "Query: `%s`,\n", query)
		return
	}

	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	_, _ = w.WriteString("Query: ")
	for i, line := range strings.Split(strings.TrimSpace(query), "\n") {
		if i > 0 {
			_, _ = w.WriteString(" + \n")
		}
		// Empty lines are emitted as "\n" rather than dropped. Otherwise two
		// separators would meet ("+ +") and the generated source would not compile.
		_, _ = fmt.Fprintf(w, `"%s\n"`, escaper.Replace(line))
	}
	// Terminating literal plus field comma, then a blank line (matching the
	// original output format).
	_, _ = w.WriteString(" + \n\"\",\n\n")
}

// generatePlansForScriptSuite writes `var <spec.Name> = []ScriptTest{...}` to w,
// regenerating queries.QueryPlanScriptTests.
//
// For each script it re-emits Dialect (when non-empty), Name and SetUpScript. For
// each assertion it re-emits Query, Expected rows and (if set) Skip.
//
// For every non-skipped assertion it also:
//  1. obtains an engine via harness.NewEngine;
//  2. replays the script's setup queries on that engine;
//  3. writes ExpectedPlan as the debug description of the analyzed assertion query.
//
// String fields are written with %q, so quotes, backslashes and newlines in them
// still produce valid Go literals.
//
// Engine-creation and setup-query failures are returned. Parse and analysis
// failures exit the program (see analyzeQuery).
func generatePlansForScriptSuite(spec PlanSpec, w *bytes.Buffer) error {
	harness := enginetest.NewMemoryHarness("default", 1, 1, true, nil)
	harness.Setup(setup.MydbData)
	_, _ = fmt.Fprintf(w, "var %s = []ScriptTest{\n", spec.Name)
	for _, tt := range queries.QueryPlanScriptTests {
		_, _ = w.WriteString("\t{\n")
		if tt.Dialect != "" {
			_, _ = fmt.Fprintf(w, "\t\tDialect: %q,\n", tt.Dialect)
		}
		_, _ = fmt.Fprintf(w, "\t\tName: %q,\n", tt.Name)
		_, _ = w.WriteString("\t\tSetUpScript: []string{\n")
		for _, setupQuery := range tt.SetUpScript {
			_, _ = fmt.Fprintf(w, "\t\t\t%q,\n", setupQuery)
		}
		_, _ = w.WriteString("\t\t},\n")
		_, _ = w.WriteString("\t\tAssertions: []ScriptTestAssertion{\n")
		for _, assertion := range tt.Assertions {
			_, _ = w.WriteString("\t\t\t{\n")
			if assertion.Skip {
				_, _ = w.WriteString("\t\t\t\tSkip: true,\n")
			}
			_, _ = fmt.Fprintf(w, "\t\t\t\tQuery: %q,\n", assertion.Query)
			_, _ = w.WriteString("\t\t\t\tExpected: []sql.Row{\n")
			for _, expRow := range assertion.Expected {
				_, _ = fmt.Fprintf(w, "\t\t\t\t\t%#v,\n", expRow)
			}
			_, _ = w.WriteString("\t\t\t\t},\n")

			if !assertion.Skip {
				engine, err := harness.NewEngine(nil)
				if err != nil {
					return err
				}
				ctx := enginetest.NewContext(harness)
				for _, setupQuery := range tt.SetUpScript {
					ctx = ctx.WithQuery(setupQuery)
					_, iter, _, err := engine.Query(ctx, setupQuery)
					if err != nil {
						return fmt.Errorf("%w\nfailed to execute setup query: %s", err, setupQuery)
					}
					// Drain the iterator so the setup statement actually runs to completion.
					if _, err = sql.RowIterToRows(ctx, iter); err != nil {
						return fmt.Errorf("%w\nfailed to execute setup query: %s", err, setupQuery)
					}
				}

				node := analyzeQuery(ctx, engine, assertion.Query)
				_, _ = w.WriteString("\t\t\t\tExpectedPlan: ")
				writePlanString(w, sql.Describe(enginetest.ExtractQueryNode(node), sql.DescribeOptions{
					Debug: true,
				}))
			}
			_, _ = w.WriteString("\t\t\t},\n")
		}
		_, _ = w.WriteString("\t\t},\n")
		_, _ = w.WriteString("\t},\n")
	}
	_, _ = w.WriteString("}")

	return nil
}

// specSetup returns the setup scripts that prepare the database for the plan suite
// called name. An unknown suite name is fatal: it is reported and the program exits.
func specSetup(name string) [][]setup.SetupScript {
	setupsBySuite := map[string][][]setup.SetupScript{
		"PlanTests":                setup.PlanSetup,
		"IndexPlanTests":           setup.ComplexIndexSetup,
		"ImdbPlanTests":            setup.ImdbPlanSetup,
		"TpchPlanTests":            setup.TpchPlanSetup,
		"TpcdsPlanTests":           setup.TpcdsPlanSetup,
		"IntegrationPlanTests":     setup.IntegrationPlanSetup,
		"TpccPlanTests":            setup.TpccPlanSetup,
		"GeneratedColumnPlanTests": setup.GeneratedColumnSetup,
		"SysbenchPlanTests":        setup.SysbenchSetup,
	}
	if scripts, found := setupsBySuite[name]; found {
		return scripts
	}
	exit(fmt.Errorf("setup not found for plan suite: %s", name))
	return nil
}

// specQueries returns the query plan tests that make up the plan suite called
// name. An unknown suite name is fatal: it is reported and the program exits.
func specQueries(name string) []queries.QueryPlanTest {
	testsBySuite := map[string][]queries.QueryPlanTest{
		"PlanTests":                queries.PlanTests,
		"IndexPlanTests":           queries.IndexPlanTests,
		"ImdbPlanTests":            queries.ImdbPlanTests,
		"TpchPlanTests":            queries.TpchPlanTests,
		"TpccPlanTests":            queries.TpccPlanTests,
		"TpcdsPlanTests":           queries.TpcdsPlanTests,
		"IntegrationPlanTests":     queries.IntegrationPlanTests,
		"GeneratedColumnPlanTests": queries.GeneratedColumnPlanTests,
		"SysbenchPlanTests":        queries.SysbenchPlanTests,
	}
	if tests, found := testsBySuite[name]; found {
		return tests
	}
	exit(fmt.Errorf("queries not found for plan suite: %s", name))
	return nil
}

// usage is a replacement usage function for the flags package. It prints the
// following to stderr:
//   - a one-line description;
//   - the command-line synopsis (`plan` is the only supported command, followed by
//     the path to the spec file);
//   - the registered flags.
func usage() {
	fmt.Fprint(os.Stderr, "Plangen generates expected plan tests.\n\n")
	fmt.Fprint(os.Stderr, "Usage:\n")

	fmt.Fprint(os.Stderr, "\tplangen [flags] plan spec\n\n")

	fmt.Fprint(os.Stderr, "Flags:\n")

	flag.PrintDefaults()

	fmt.Fprint(os.Stderr, "\n")
}

// exit reports err on stderr, prefixed with "ERROR: ", and terminates the process
// with exit status 2. It never returns.
func exit(err error) {
	message := fmt.Sprintf("ERROR: %v\n", err)
	fmt.Fprint(os.Stderr, message)
	os.Exit(2)
}

func toFile(buf bytes.Buffer, out string) error {
	var w io.Writer
	if out != "" {
		file, err := os.Create(out)
		if err != nil {
			exit(err)
		}

		defer file.Close()
		w = file
	} else {
		w = os.Stderr
	}

	var b []byte
	var err error

	if useGoFmt {
		b, err = format.Source(buf.Bytes())
		if err != nil {
			// Write out incorrect source for easier debugging.
			b = buf.Bytes()
		}
	} else {
		b = buf.Bytes()
	}

	w.Write(b)
	return err
}
